import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.runner import auto_fp16

from ..builder import NECKS

eps = 1e-4

@NECKS.register_module()
class BIFPN(nn.Module):
    r"""Feature Pyramid Network.

    This is an implementation of paper `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        stack (int): Number of BiFPN modules stacked on top of the lateral
            convs. Default: 1.
        add_extra_convs (bool): Whether to build the extra pyramid levels with
            stride-2 convs applied to the last backbone feature map. If False,
            the extra levels are built with max pooling instead. Default: True.
        extra_convs_on_inputs (bool, deprecated): Kept for interface
            compatibility with FPN; not used by this implementation.
            Default: False.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to skip norm on the lateral convs.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: `dict(type='BN', requires_grad=False)`.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Config dict for interpolate layer.
            Default: `dict(mode='nearest')`

    Example:
        >>> # adjacent scales must differ by exactly a factor of 2
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [344, 172, 86, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = BIFPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 344, 344])
        outputs[1].shape = torch.Size([1, 11, 172, 172])
        outputs[2].shape = torch.Size([1, 11, 86, 86])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 stack=1,
                 add_extra_convs=True,
                 extra_convs_on_inputs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=False),
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super(BIFPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        self.act_cfg = act_cfg

        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False

        self.stack = stack  # number of stacked BiFPN modules

        self.upsample_cfg = upsample_cfg.copy()
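        # note: stored for interface parity with FPN-style necks; the BiFPN
        # modules below resample with fixed nearest-neighbour scale factors
        # (2 and 0.5) and do not read this config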

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level

        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        self.extra_convs_on_inputs = extra_convs_on_inputs

        self.lateral_convs = nn.ModuleList()
        self.stack_bifpn_convs = nn.ModuleList()

        if self.add_extra_convs:
            self.extra_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=None,  # plain 1x1 projection, no activation
                inplace=False)
            self.lateral_convs.append(l_conv)

        # number of extra pyramid levels on top of the backbone outputs
        # (e.g., P6/P7 for RetinaNet-style heads)
        self.extra_levels = num_outs - self.backbone_end_level + self.start_level

        if self.extra_levels > 0:
            # the extra levels are derived from the last backbone feature, so
            # they keep its channel count until the 1x1 lateral conv
            extra_in_channels = self.in_channels[self.backbone_end_level - 1]
            for i in range(self.extra_levels):
                extra_l_conv = ConvModule(
                    extra_in_channels,
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                    act_cfg=None,
                    inplace=False)
                self.lateral_convs.append(extra_l_conv)

                if self.add_extra_convs:
                    # stride-2 conv that keeps the channel count; the lateral
                    # conv above projects the result to ``out_channels``
                    extra_conv = ConvModule(
                        extra_in_channels,
                        extra_in_channels,
                        3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=self.act_cfg,
                        inplace=False)
                    self.extra_convs.append(extra_conv)

        # each BiFPNModule performs one full top-down + bottom-up fusion pass;
        # ``stack`` controls how many such bidirectional layers are chained
        num_levels = (self.backbone_end_level - self.start_level +
                      self.extra_levels)
        for _ in range(stack):
            self.stack_bifpn_convs.append(
                BiFPNModule(
                    channels=out_channels,
                    levels=num_levels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # backbone outputs usually arrive as a tuple; use a mutable list so
        # the extra levels can be appended below
        inputs = list(inputs)

        # build the extra pyramid levels (e.g. P6/P7) from the topmost
        # backbone feature, either with the stride-2 extra convs or, when they
        # are disabled, with max pooling
        if self.extra_levels > 0:
            for i in range(self.extra_levels):
                if self.add_extra_convs:
                    inputs.append(self.extra_convs[i](inputs[-1]))
                else:
                    inputs.append(F.max_pool2d(inputs[-1], 1, stride=2))

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        for bifpn_module in self.stack_bifpn_convs:
            laterals = bifpn_module(laterals)
        outs = laterals
        return tuple(outs)

    def init_weights(self):
        """Initialize the weights of FPN module."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')


class BiFPNModule(nn.Module):
    """One bidirectional (top-down followed by bottom-up) fusion layer of the
    BiFPN, with the fast normalized fusion weights from EfficientDet."""

    def __init__(self,
                 channels,
                 levels,
                 init=0.5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None):
        super(BiFPNModule, self).__init__()
        self.act_cfg = act_cfg
        self.levels = levels
        self.bifpn_convs = nn.ModuleList()
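        # learnable fusion weights (fast normalized fusion): the ``levels``
        # columns of ``w1`` cover the two-input nodes (levels - 1 top-down
        # nodes plus the topmost bottom-up node), while ``w2`` covers the
        # levels - 2 three-input nodes of the bottom-up pass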
        self.w1 = nn.Parameter(torch.Tensor(2, levels).fill_(init))
        self.relu1 = nn.ReLU()
        self.w2 = nn.Parameter(torch.Tensor(3, levels - 2).fill_(init))
        self.relu2 = nn.ReLU()
        # 2 * (levels - 1) depthwise separable 3x3 convs: one for every fusion
        # node of the top-down pass and one for every node of the bottom-up
        # pass
        for _ in range(2):
            for _ in range(self.levels - 1):
                fpn_conv = nn.Sequential(
                    ConvModule(
                        channels,
                        channels,
                        3,
                        padding=1,
                        groups=channels,
                        conv_cfg=conv_cfg,
                        norm_cfg=None,
                        act_cfg=None,
                        inplace=False),
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=self.act_cfg,
                        inplace=False))
                self.bifpn_convs.append(fpn_conv)

    def forward(self, inputs):
        assert len(inputs) == self.levels
        levels = self.levels
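        # fast normalized fusion (EfficientDet): every fused feature is
        #     O = sum_i(w_i * I_i) / (eps + sum_j(w_j)),  with w_i = relu(w_i)
        # so relu keeps the weights non-negative and the division makes each
        # node's weights sum to roughly one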
        # out-of-place division so the relu outputs are not modified in place
        w1 = self.relu1(self.w1)
        w1 = w1 / (torch.sum(w1, dim=0) + eps)
        w2 = self.relu2(self.w2)
        w2 = w2 / (torch.sum(w2, dim=0) + eps)
        # top-down path: start from the coarsest level and repeatedly fuse the
        # upsampled running feature with the lateral input of the next finer
        # level
        kk = 0
        pathtd = [inputs[levels - 1]]
        for i in range(levels - 1, 0, -1):
            _t = w1[0, kk] * inputs[i - 1] + w1[1, kk] * F.interpolate(
                pathtd[-1], scale_factor=2, mode='nearest')
            pathtd.append(self.bifpn_convs[kk](_t))
            del _t
            kk = kk + 1
        jj = kk
        # reverse so that pathtd[i] corresponds to pyramid level i again
        pathtd = pathtd[::-1]
        # bottom-up path: every intermediate node fuses its lateral input, the
        # downsampled finer feature and its own top-down result
        for i in range(0, levels - 2, 1):
            pathtd[i + 1] = (
                w2[0, i] * inputs[i + 1] +
                w2[1, i] * F.interpolate(
                    pathtd[i], scale_factor=0.5, mode='nearest') +
                w2[2, i] * pathtd[i + 1])
            pathtd[i + 1] = self.bifpn_convs[jj](pathtd[i + 1])
            jj = jj + 1

        # the topmost node only fuses its lateral input with the downsampled
        # feature from the level below
        pathtd[levels - 1] = (
            w1[0, kk] * inputs[levels - 1] +
            w1[1, kk] * F.interpolate(
                pathtd[levels - 2], scale_factor=0.5, mode='nearest'))
        pathtd[levels - 1] = self.bifpn_convs[jj](pathtd[levels - 1])

        return pathtd

    def init_weights(self):
        """Initialize the weights of FPN module."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
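
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the channel/stage numbers below are
# assumptions, not values required by this module). Adjacent input levels must
# differ in resolution by exactly a factor of 2, because BiFPNModule resamples
# with fixed scale factors of 2 and 0.5.
#
#     import torch
#     neck = BIFPN(in_channels=[256, 512, 1024, 2048],
#                  out_channels=256,
#                  num_outs=5,
#                  stack=2).eval()
#     feats = [torch.rand(1, c, s, s)
#              for c, s in zip([256, 512, 1024, 2048], [64, 32, 16, 8])]
#     outs = neck(feats)  # tuple of 5 maps, each with 256 channels
#
# In an mmdetection config this neck is selected through the registry, e.g.
#     neck=dict(type='BIFPN', in_channels=[256, 512, 1024, 2048],
#               out_channels=256, num_outs=5, stack=2)
# ---------------------------------------------------------------------------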